697922
@@ -44,12 +44,11 @@
 import org.apache.calcite.rel.logical.LogicalAggregate;
 import org.apache.calcite.rel.logical.LogicalCorrelate;
 import org.apache.calcite.rel.logical.LogicalFilter;
-import org.apache.calcite.rel.logical.LogicalIntersect;
 import org.apache.calcite.rel.logical.LogicalJoin;
 import org.apache.calcite.rel.logical.LogicalProject;
-import org.apache.calcite.rel.logical.LogicalUnion;
 import org.apache.calcite.rel.metadata.RelMdUtil;
 import org.apache.calcite.rel.metadata.RelMetadataQuery;
+import org.apache.calcite.rel.rules.FilterCorrelateRule;
 import org.apache.calcite.rel.rules.FilterJoinRule;
 import org.apache.calcite.rel.rules.FilterProjectTransposeRule;
 import org.apache.calcite.rel.type.RelDataType;
@@ -62,7 +61,6 @@
 import org.apache.calcite.rex.RexInputRef;
 import org.apache.calcite.rex.RexLiteral;
 import org.apache.calcite.rex.RexNode;
-import org.apache.calcite.rex.RexOver;
 import org.apache.calcite.rex.RexShuttle;
 import org.apache.calcite.rex.RexSubQuery;
 import org.apache.calcite.rex.RexUtil;
@@ -87,11 +85,10 @@
 import org.apache.calcite.util.mapping.Mappings;
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveFilter;
-import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveIntersect;
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin;
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject;
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSortLimit;
-import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveUnion;
+import org.apache.hadoop.hive.ql.parse.SemanticAnalyzer;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -122,7 +119,6 @@
 import java.util.Objects;
 import java.util.Set;
 import java.util.SortedMap;
-import java.util.Stack;
 import java.util.TreeMap;
 import java.util.TreeSet;
 import javax.annotation.Nonnull;
@@ -181,8 +177,6 @@
 
   private final HashSet<LogicalCorrelate> generatedCorRels = Sets.newHashSet();
 
-  private Stack valueGen = new Stack();
-
   //~ Constructors -----------------------------------------------------------
 
   private HiveRelDecorrelator (
@@ -269,8 +263,6 @@
private RelNode decorrelate(RelNode root) {
       return planner2.findBestExp();
     }
 
-    assert(valueGen.isEmpty());
-
     return root;
   }
 
@@ -328,14 +320,8 @@
public RelNode removeCorrelationViaRule(RelNode root) {
     return planner.findBestExp();
   }
 
-  protected RexNode decorrelateExpr(RexNode exp, boolean valueGenerator) {
-    DecorrelateRexShuttle shuttle = new DecorrelateRexShuttle();
-    shuttle.setValueGenerator(valueGenerator);
-    return exp.accept(shuttle);
-  }
   protected RexNode decorrelateExpr(RexNode exp) {
     DecorrelateRexShuttle shuttle = new DecorrelateRexShuttle();
-    shuttle.setValueGenerator(false);
     return exp.accept(shuttle);
   }
 
@@ -1121,11 +1107,7 @@
private Frame decorrelateInputWithValueGenerator(RelNode rel) {
         try {
           findCorrelationEquivalent(correlation, ((Filter) rel).getCondition());
         } catch (Util.FoundOne e) {
-          // we need to keep predicate kind e.g. EQUAL or NOT EQUAL
-          // so that later while decorrelating LogicalCorrelate appropriate join predicate
-          // is generated
-          def.setPredicateKind((SqlKind)((Pair)e.getNode()).getValue());
-          map.put(def, (Integer)((Pair) e.getNode()).getKey());
+          map.put(def, (Integer) e.getNode());
         }
       }
       // If all correlation variables are now satisfied, skip creating a value
@@ -1164,22 +1146,16 @@
private Frame decorrelateInputWithValueGenerator(RelNode rel) {
   private void findCorrelationEquivalent(CorRef correlation, RexNode e)
           throws Util.FoundOne {
     switch (e.getKind()) {
-    // for now only EQUAL and NOT EQUAL corr predicates are optimized
-    case NOT_EQUALS:
-      if((boolean)valueGen.peek()) {
-        // we will need value generator
-        break;
-      }
-    case EQUALS:
+      case EQUALS:
         final RexCall call = (RexCall) e;
         final List<RexNode> operands = call.getOperands();
         if (references(operands.get(0), correlation)
                 && operands.get(1) instanceof RexInputRef) {
-          throw new Util.FoundOne(Pair.of(((RexInputRef) operands.get(1)).getIndex(), e.getKind()));
+          throw new Util.FoundOne(((RexInputRef) operands.get(1)).getIndex());
         }
         if (references(operands.get(1), correlation)
                 && operands.get(0) instanceof RexInputRef) {
-          throw new Util.FoundOne(Pair.of(((RexInputRef) operands.get(0)).getIndex(), e.getKind()));
+          throw new Util.FoundOne(((RexInputRef) operands.get(0)).getIndex());
         }
         break;
       case AND:
@@ -1248,22 +1224,17 @@
public Frame decorrelateRel(HiveFilter rel) throws SemanticException {
         return null;
       }
 
-      Frame oldInputFrame = frame;
       // If this LogicalFilter has correlated reference, create value generator
       // and produce the correlated variables in the new output.
       if (cm.mapRefRelToCorRef.containsKey(rel)) {
         frame = decorrelateInputWithValueGenerator(rel);
       }
 
-      boolean valueGenerator = true;
-      if(frame.r == oldInputFrame.r) {
-        // this means correated value generator wasn't generated
-        valueGenerator = false;
-      }
-        // Replace the filter expression to reference output of the join
-        // Map filter to the new filter over join
-        relBuilder.push(frame.r).filter(
-            (decorrelateExpr(rel.getCondition(), valueGenerator)));
+      // Replace the filter expression to reference output of the join
+      // Map filter to the new filter over join
+      relBuilder.push(frame.r).filter(
+          simplifyComparison(decorrelateExpr(rel.getCondition())));
+
       // Filter does not change the input ordering.
       // Filter rel does not permute the input.
       // All corvars produced by filter will have the same output positions in the
@@ -1273,6 +1244,39 @@
public Frame decorrelateRel(HiveFilter rel) throws SemanticException {
     }
   }
 
+  /**
+   * Returns {@code op} with "x op x" (x deterministic) folded. Only valid for a
+   * filter condition, where UNKNOWN acts as FALSE ("x = x" => "x IS NOT NULL").
+   */
+  private RexNode simplifyComparison(RexNode op) {
+    switch (op.getKind()) {
+    case EQUALS:
+    case GREATER_THAN_OR_EQUAL:
+    case LESS_THAN_OR_EQUAL:
+    case GREATER_THAN:
+    case LESS_THAN:
+    case NOT_EQUALS:
+      final List<RexNode> operands = ((RexCall) op).getOperands();
+      final RexNode o0 = operands.get(0);
+      // A non-deterministic operand (e.g. rand()) may differ on each side.
+      if (!RexUtil.eq(o0, operands.get(1)) || !RexUtil.isDeterministic(o0)) {
+        return op;
+      }
+      switch (op.getKind()) {
+      case EQUALS:
+      case GREATER_THAN_OR_EQUAL:
+      case LESS_THAN_OR_EQUAL:
+        // "x = x" is TRUE unless x is null (similarly <= and >=)
+        return rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, o0);
+      default:
+        // "x != x" is never TRUE (similarly < and >)
+        return rexBuilder.makeLiteral(false);
+      }
+    default:
+      return op;
+    }
+  }
+
     /**
      * Rewrite LogicalFilter.
      *
@@ -1309,15 +1313,9 @@
public Frame decorrelateRel(LogicalFilter rel) {
 
     }
 
-    boolean valueGenerator = true;
-    if(frame.r == oldInput) {
-      // this means correated value generator wasn't generated
-      valueGenerator = false;
-    }
-
     // Replace the filter expression to reference output of the join
     // Map filter to the new filter over join
-    relBuilder.push(frame.r).filter(decorrelateExpr(rel.getCondition(), valueGenerator));
+    relBuilder.push(frame.r).filter(decorrelateExpr(rel.getCondition()));
 
 
     // Filter does not change the input ordering.
@@ -1347,9 +1345,6 @@
public Frame decorrelateRel(LogicalCorrelate rel) {
     final RelNode oldLeft = rel.getInput(0);
     final RelNode oldRight = rel.getInput(1);
 
-    boolean mightRequireValueGen = new findIfValueGenRequired().traverse(oldRight);
-    valueGen.push(mightRequireValueGen);
-
     final Frame leftFrame = getInvoke(oldLeft, rel);
     final Frame rightFrame = getInvoke(oldRight, rel);
 
@@ -1386,24 +1381,11 @@
public Frame decorrelateRel(LogicalCorrelate rel) {
       }
       final int newLeftPos = leftFrame.oldToNewOutputs.get(corDef.field);
       final int newRightPos = rightOutput.getValue();
-      if(corDef.getPredicateKind() == SqlKind.NOT_EQUALS) {
-        conditions.add(
-            rexBuilder.makeCall(SqlStdOperatorTable.NOT_EQUALS,
-                RexInputRef.of(newLeftPos, newLeftOutput),
-                new RexInputRef(newLeftFieldCount + newRightPos,
-                    newRightOutput.get(newRightPos).getType())));
-
-      }
-      else {
-        assert(corDef.getPredicateKind() == null
-          || corDef.getPredicateKind() == SqlKind.EQUALS);
-        conditions.add(
-            rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
-                RexInputRef.of(newLeftPos, newLeftOutput),
-                new RexInputRef(newLeftFieldCount + newRightPos,
-                    newRightOutput.get(newRightPos).getType())));
-
-      }
+      conditions.add(
+              rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
+                      RexInputRef.of(newLeftPos, newLeftOutput),
+                      new RexInputRef(newLeftFieldCount + newRightPos,
+                              newRightOutput.get(newRightPos).getType())));
 
       // remove this cor var from output position mapping
       corDefOutputs.remove(corDef);
@@ -1446,8 +1428,6 @@
public Frame decorrelateRel(LogicalCorrelate rel) {
             LogicalJoin.create(leftFrame.r, rightFrame.r, condition,
                     ImmutableSet.<CorrelationId>of(), rel.getJoinType().toJoinType());
 
-    valueGen.pop();
-
     return register(rel, newJoin, mapOldToNewOutputs, corDefOutputs);
   }
 
@@ -1840,66 +1820,7 @@
private static RelNode stripHep(RelNode rel) {
 
   /** Shuttle that decorrelates. */
   private class DecorrelateRexShuttle extends RexShuttle {
-    private boolean valueGenerator;
-    public void setValueGenerator(boolean valueGenerator) {
-      this.valueGenerator = valueGenerator;
-    }
-
-    // DecorrelateRexShuttle ends up decorrelating expressions cor.col1 <> $4
-    // to $4=$4 if value generator is not generated, $4<>$4 is further simplified
-    // to false. This is wrong and messes up the whole tree. To prevent this visitCall
-    // is overridden to rewrite/simply such predicates to is not null.
-    // we also need to take care that we do this only for correlated predicates and
-    // not user specified explicit predicates
-    // TODO:  This code should be removed once CALCITE-1851 is fixed and
-    // there is support of not equal
-    @Override  public RexNode visitCall(final RexCall call) {
-      if(!valueGenerator) {
-        switch (call.getKind()) {
-        case EQUALS:
-        case NOT_EQUALS:
-          final List<RexNode> operands = new ArrayList<>(call.operands);
-          RexNode o0 = operands.get(0);
-          RexNode o1 = operands.get(1);
-          boolean isCorrelated = false;
-          if (o0 instanceof RexFieldAccess && (cm.mapFieldAccessToCorRef.get(o0) != null)) {
-            o0 = decorrFieldAccess((RexFieldAccess) o0);
-            isCorrelated = true;
-
-          }
-          if (o1 instanceof RexFieldAccess && (cm.mapFieldAccessToCorRef.get(o1) != null)) {
-            o1 = decorrFieldAccess((RexFieldAccess) o1);
-            isCorrelated = true;
-          }
-          if (isCorrelated && RexUtil.eq(o0, o1)) {
-            return rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, o0);
-          }
-
-          final List<RexNode> newOperands = new ArrayList<>();
-          newOperands.add(o0);
-          newOperands.add(o1);
-          boolean[] update = { false };
-          List<RexNode> clonedOperands = visitList(newOperands, update);
-
-          return relBuilder.call(call.getOperator(), clonedOperands);
-        }
-      }
-      return super.visitCall(call);
-    }
-
     @Override public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
-      return decorrFieldAccess(fieldAccess);
-    }
-
-    @Override public RexNode visitInputRef(RexInputRef inputRef) {
-      final RexInputRef ref = getNewForOldInputRef(inputRef);
-      if (ref.getIndex() == inputRef.getIndex()
-              && ref.getType() == inputRef.getType()) {
-        return inputRef; // re-use old object, to prevent needless expr cloning
-      }
-      return ref;
-    }
-    private RexNode decorrFieldAccess(RexFieldAccess fieldAccess) {
       int newInputOutputOffset = 0;
       for (RelNode input : currentRel.getInputs()) {
         final Frame frame = map.get(input);
@@ -1914,7 +1835,7 @@
private RexNode decorrFieldAccess(RexFieldAccess fieldAccess) {
               // This input rel does produce the cor var referenced.
               // Assume fieldAccess has the correct type info.
               return new RexInputRef(newInputPos + newInputOutputOffset,
-                  frame.r.getRowType().getFieldList().get(newInputPos)
+                      frame.r.getRowType().getFieldList().get(newInputPos)
                       .getType());
             }
           }
@@ -1928,6 +1849,15 @@
private RexNode decorrFieldAccess(RexFieldAccess fieldAccess) {
       }
       return fieldAccess;
     }
+
+    @Override public RexNode visitInputRef(RexInputRef inputRef) {
+      final RexInputRef mapped = getNewForOldInputRef(inputRef);
+      // Hand back the original reference when the mapping is an identity so
+      // that unchanged expressions are not needlessly cloned.
+      final boolean unchanged = mapped.getIndex() == inputRef.getIndex()
+          && mapped.getType() == inputRef.getType();
+      return unchanged ? inputRef : mapped;
+    }
   }
 
   /** Shuttle that removes correlations. */
@@ -2952,12 +2882,10 @@
public CorDef def() {
   static class CorDef implements Comparable<CorDef> {
     public final CorrelationId corr;
     public final int field;
-    private SqlKind predicateKind;
 
     CorDef(CorrelationId corr, int field) {
       this.corr = corr;
       this.field = field;
-      this.predicateKind = null;
     }
 
     @Override public String toString() {
@@ -2982,13 +2910,6 @@
public int compareTo(@Nonnull CorDef o) {
       }
       return Integer.compare(field, o.field);
     }
-    public SqlKind getPredicateKind() {
-      return predicateKind;
-    }
-    public void setPredicateKind(SqlKind predKind) {
-      this.predicateKind = predKind;
-
-    }
   }
 
   /** A map of the locations of
@@ -3066,107 +2987,6 @@
public boolean hasCorrelation() {
     }
   }
 
-  private static class findIfValueGenRequired extends HiveRelShuttleImpl {
-    private boolean mightRequireValueGen ;
-    findIfValueGenRequired() { this.mightRequireValueGen = true; }
-
-    private boolean hasRexOver(List<RexNode> projects) {
-      for(RexNode expr : projects) {
-        if(expr instanceof  RexOver) {
-          return true;
-        }
-      }
-      return false;
-    }
-    @Override public RelNode visit(HiveJoin rel) {
-      mightRequireValueGen = true;
-      return rel;
-    }
-    public RelNode visit(HiveSortLimit rel) {
-      mightRequireValueGen = true;
-      return rel;
-    }
-    public RelNode visit(HiveUnion rel) {
-      mightRequireValueGen = true;
-      return rel;
-    }
-    public RelNode visit(LogicalUnion rel) {
-      mightRequireValueGen = true;
-      return rel;
-    }
-    public RelNode visit(LogicalIntersect rel) {
-      mightRequireValueGen = true;
-      return rel;
-    }
-
-    public RelNode visit(HiveIntersect rel) {
-      mightRequireValueGen = true;
-      return rel;
-    }
-
-    @Override public RelNode visit(LogicalJoin rel) {
-      mightRequireValueGen = true;
-      return rel;
-    }
-    @Override public RelNode visit(HiveProject rel) {
-      if(!(hasRexOver(((HiveProject)rel).getProjects()))) {
-        mightRequireValueGen = false;
-        return super.visit(rel);
-      }
-      else {
-        mightRequireValueGen = true;
-        return rel;
-      }
-    }
-    @Override public RelNode visit(LogicalProject rel) {
-      if(!(hasRexOver(((LogicalProject)rel).getProjects()))) {
-        mightRequireValueGen = false;
-        return super.visit(rel);
-      }
-      else {
-        mightRequireValueGen = true;
-        return rel;
-      }
-    }
-    @Override public RelNode visit(HiveAggregate rel) {
-      // if there are aggregate functions or grouping sets we will need
-      // value generator
-      if((((HiveAggregate)rel).getAggCallList().isEmpty() == true
-          && ((HiveAggregate)rel).indicator == false)) {
-        this.mightRequireValueGen = false;
-        return super.visit(rel);
-      }
-      else {
-        // need to reset to true in case previous aggregate/project
-        // has set it to false
-        this.mightRequireValueGen = true;
-        return rel;
-      }
-    }
-    @Override public RelNode visit(LogicalAggregate rel) {
-      if((((LogicalAggregate)rel).getAggCallList().isEmpty() == true
-          && ((LogicalAggregate)rel).indicator == false)) {
-        this.mightRequireValueGen = false;
-        return super.visit(rel);
-      }
-      else {
-        // need to reset to true in case previous aggregate/project
-        // has set it to false
-        this.mightRequireValueGen = true;
-        return rel;
-      }
-    }
-    @Override public RelNode visit(LogicalCorrelate rel) {
-      // this means we are hitting nested subquery so don't
-      // need to go further
-      return rel;
-    }
-
-    public boolean traverse(RelNode root) {
-      root.accept(this);
-      return mightRequireValueGen;
-    }
-  }
   /** Builds a {@link org.apache.calcite.sql2rel.RelDecorrelator.CorelMap}. */
   private static class CorelMapBuilder extends HiveRelShuttleImpl {
     final SortedMap<CorrelationId, RelNode> mapCorToCorRel =
